{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from CLEAN.infer import *\n",
    "# pre-compute esm-1b embedding for sequences in new-392 dataset\n",
    "ensure_dirs(\"data/esm_data\")\n",
    "ensure_dirs(\"data/pretrained\")\n",
    "csv_to_fasta(\"data/new.csv\", \"data/new.fasta\")\n",
    "retrive_esm1b_embedding(\"new\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of EC numbers with only one sequences: 1496\n",
      "Number of single-seq EC number sequences need to mutate:  0\n",
      "Number of single-seq EC numbers already mutated:  1496\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2702/2702 [00:00<00:00, 2786.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating distance map, number of unique EC is 2702\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2702it [00:01, 2077.21it/s]\n"
     ]
    }
   ],
   "source": [
    "from CLEAN.infer import *\n",
    "train_file = \"split10\"\n",
    "# this will mutate a sequence whose EC number only has this sequence\n",
    "# 10 mutated embedding will be saved for each such sequence\n",
    "train_fasta_file = mutate_single_seq_ECs(train_file)\n",
    "retrive_esm1b_embedding(train_fasta_file)\n",
    "\n",
    "# this will save the distance matrix and esm embedding matrix \n",
    "# (split10.pkl and split10_esm.pkl) in folder '/data/distance_map'\n",
    "compute_esm_distance(train_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5242/5242 [00:00<00:00, 11050.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "392it [00:00, 1125.91it/s]\n",
      "100%|██████████| 5242/5242 [00:00<00:00, 11203.57it/s]\n",
      "20000it [00:17, 1141.39it/s]\n",
      "100%|██████████| 392/392 [00:08<00:00, 45.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############ EC calling results using random chosen 20k samples ############\n",
      "---------------------------------------------------------------------------\n",
      ">>> total samples: 392 | total ec: 177 \n",
      ">>> precision: 0.558 | recall: 0.477| F1: 0.482 | AUC: 0.737 \n",
      "---------------------------------------------------------------------------\n",
      "The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5242/5242 [00:00<00:00, 9752.09it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "392it [00:00, 894.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############ EC calling results using maximum separation ############\n",
      "---------------------------------------------------------------------------\n",
      ">>> total samples: 392 | total ec: 177 \n",
      ">>> precision: 0.596 | recall: 0.479| F1: 0.497 | AUC: 0.739 \n",
      "---------------------------------------------------------------------------\n",
      "================================================================\n",
      "================================================================\n",
      "The embedding sizes for train and test: torch.Size([8359, 128]) torch.Size([392, 128])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2702/2702 [00:00<00:00, 13546.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating eval distance map, between 392 test ids and 2702 train EC cluster centers\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "392it [00:00, 1800.14it/s]\n",
      "100%|██████████| 2702/2702 [00:00<00:00, 13833.68it/s]\n",
      "20000it [00:11, 1739.29it/s]\n",
      "100%|██████████| 392/392 [00:02<00:00, 187.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############ EC calling results using random chosen 20k samples ############\n",
      "---------------------------------------------------------------------------\n",
      ">>> total samples: 392 | total ec: 177 \n",
      ">>> precision: 0.284 | recall: 0.254| F1: 0.257 | AUC: 0.627 \n",
      "---------------------------------------------------------------------------\n",
      "The embedding sizes for train and test: torch.Size([8359, 128]) torch.Size([392, 128])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2702/2702 [00:00<00:00, 15211.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating eval distance map, between 392 test ids and 2702 train EC cluster centers\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "392it [00:00, 1972.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############ EC calling results using maximum separation ############\n",
      "---------------------------------------------------------------------------\n",
      ">>> total samples: 392 | total ec: 177 \n",
      ">>> precision: 0.33 | recall: 0.284| F1: 0.287 | AUC: 0.642 \n",
      "---------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "from CLEAN.infer import *\n",
    "# testing on new-392 dataset with pretrained weights trained on 100% split\n",
    "infer_pvalue(\"split100\", \"new\", p_value=1e-5, nk_random=20,\n",
    "             report_metrics=True, pretrained=True)\n",
    "infer_maxsep(\"split100\", \"new\", report_metrics=True, pretrained=True)\n",
    "\n",
    "print(\"================================================================\")\n",
    "print(\"================================================================\")\n",
    "\n",
    "# testing on new-392 with a model trained on 10% split and triplet loss\n",
    "infer_pvalue(\"split10\", \"new\", p_value=1e-5, nk_random=20,\n",
    "             report_metrics=True, pretrained=False, model_name=\"split10_triplet\")\n",
    "infer_maxsep(\"split10\", \"new\", report_metrics=True, \n",
    "             pretrained=False, model_name=\"split10_triplet\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "386218770bb7053658aedbdb94aaaba888065d92b04918111f39a883f4943438"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
